{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# torch_scatter\n",
    "+ scatter\n",
    "+ segment_coo\n",
    "+ segment_csr  \n",
    "\n",
    "Documentation: [https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html](https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## scatter"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](https://raw.githubusercontent.com/rusty1s/pytorch_scatter/master/docs/source/_figures/add.svg?sanitize=true)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
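    "`scatter` reduces the values of `src` in groups given by the `index` tensor. It works for inputs of any number of dimensions: along the dimension being grouped (`dim`), the output size shrinks to the number of groups.  \n",
    "Options for `reduce`: {\"sum\" (default), \"mul\", \"mean\", \"min\", \"max\"}"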
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-06-11T08:48:23.168927Z",
     "start_time": "2021-06-11T08:48:22.464651Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dim 0: 4 groups\n",
      "torch.Size([4, 6])\n",
      "tensor([[2., 2., 2., 2., 2., 2.],\n",
      "        [3., 3., 3., 3., 3., 3.],\n",
      "        [1., 1., 1., 1., 1., 1.],\n",
      "        [4., 4., 4., 4., 4., 4.]])\n",
      "Dim 1: 3 groups\n",
      "torch.Size([10, 3])\n",
      "tensor([[2., 3., 1.],\n",
      "        [2., 3., 1.],\n",
      "        [2., 3., 1.],\n",
      "        [2., 3., 1.],\n",
      "        [2., 3., 1.],\n",
      "        [2., 3., 1.],\n",
      "        [2., 3., 1.],\n",
      "        [2., 3., 1.],\n",
      "        [2., 3., 1.],\n",
      "        [2., 3., 1.]])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch_scatter import scatter\n",
    "\n",
    "src = torch.ones(10, 6)\n",
    "index_0 = torch.tensor([0, 1, 0, 1, 2, 1, 3, 3, 3, 3])  # one group id per row (dim 0)\n",
    "index_1 = torch.tensor([0, 1, 0, 1, 2, 1])  # one group id per column (dim 1)\n",
    "\n",
    "print(\"Dim 0: 4 groups\")\n",
    "out_0 = scatter(src, index_0, dim=0, reduce=\"sum\")\n",
    "print(out_0.shape)\n",
    "print(out_0)\n",
    "\n",
    "print(\"Dim 1: 3 groups\")\n",
    "out_1 = scatter(src, index_1, dim=1, reduce=\"sum\")\n",
    "print(out_1.shape)\n",
    "print(out_1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
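    "Note: unless `dim_size` is given, the number of groups is determined by the largest index value (`index.max() + 1`). If an index value in between never occurs, that group is treated as empty and its output is 0 (as with `sum` here)."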
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-06-07T12:06:59.468519Z",
     "start_time": "2021-06-07T12:06:59.458525Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([10, 4])\n",
      "tensor([[3., 0., 1., 2.],\n",
      "        [3., 0., 1., 2.],\n",
      "        [3., 0., 1., 2.],\n",
      "        [3., 0., 1., 2.],\n",
      "        [3., 0., 1., 2.],\n",
      "        [3., 0., 1., 2.],\n",
      "        [3., 0., 1., 2.],\n",
      "        [3., 0., 1., 2.],\n",
      "        [3., 0., 1., 2.],\n",
      "        [3., 0., 1., 2.]])\n"
     ]
    }
   ],
   "source": [
    "src = torch.ones(10, 6)\n",
    "index_1 = torch.tensor([0, 2, 0, 3, 0, 3])\n",
    "out_1 = scatter(src, index_1, dim=1, reduce=\"sum\")\n",
    "print(out_1.shape)\n",
    "print(out_1)"
   ]
  },
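  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of overriding the number of groups with the `dim_size` argument of `scatter`. The value 6 is an arbitrary choice for illustration; the two extra groups receive no elements, so they are expected to come out as 0 for `sum`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch_scatter import scatter\n",
    "\n",
    "src = torch.ones(10, 6)\n",
    "index_1 = torch.tensor([0, 2, 0, 3, 0, 3])\n",
    "\n",
    "# Without dim_size the output has index_1.max() + 1 = 4 columns.\n",
    "# Requesting 6 groups appends two more empty groups (4 and 5).\n",
    "out_1 = scatter(src, index_1, dim=1, dim_size=6, reduce=\"sum\")\n",
    "print(out_1.shape)  # expected: torch.Size([10, 6])\n",
    "print(out_1[0])     # expected: tensor([3., 0., 1., 2., 0., 0.])"
   ]
  },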
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-06-09T11:24:08.388977Z",
     "start_time": "2021-06-09T11:24:08.385119Z"
    },
    "heading_collapsed": true
   },
   "source": [
    "## COO and CSR are sparse-matrix storage formats"
   ]
  },
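  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the relationship between the two layouts concrete, here is a small sketch using only plain `torch` calls. A sorted COO-style `index` stores one group id per element, while a CSR-style `indptr` stores only the boundaries between groups. The variable names are chosen for illustration; the conversion below assumes `index` is already sorted."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Sorted COO-style group ids, one per element.\n",
    "index = torch.tensor([0, 1, 1, 1, 1, 2])\n",
    "\n",
    "# CSR-style boundaries: group i covers positions indptr[i]:indptr[i + 1].\n",
    "counts = torch.bincount(index)\n",
    "indptr = torch.cat([counts.new_zeros(1), counts.cumsum(0)])\n",
    "print(indptr)  # tensor([0, 1, 5, 6])\n",
    "\n",
    "# Going back: repeat each group id by the length of its segment.\n",
    "lengths = indptr[1:] - indptr[:-1]\n",
    "index_back = torch.repeat_interleave(torch.arange(len(lengths)), lengths)\n",
    "print(index_back)  # tensor([0, 1, 1, 1, 1, 2])"
   ]
  },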
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## segment_coo"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
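    "`src` shape: $(x_0,...,x_{m-1},x_m,x_{m+1},...,x_n)$  \n",
    "`index` shape: $(x_0,...,x_{m-1},x_m)$  \n",
    "`out` shape: $(x_0,...,x_{m-1},y,x_{m+1},...,x_n)$  \n",
    "Reshape `index` with `view` so that every dimension before the grouping dimension is 1 and the grouping dimension is -1; it is then broadcast against `src`.\n",
    "\n",
    "The values of `index` must be sorted in ascending (non-decreasing) order; otherwise the result is wrong or the program may crash. The function groups by runs: consecutive equal index values are placed in one group. The example below is processed as  \n",
    "[0]  \n",
    "[0] [1 1 1 1]  \n",
    "[0] [1 1 1 1] [2]  \n",
    "\n",
    "`segment_coo` is generally faster than `scatter`."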
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-06-07T12:12:10.634230Z",
     "start_time": "2021-06-07T12:12:10.545203Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([10, 3])\n",
      "tensor([[1., 4., 1.],\n",
      "        [1., 4., 1.],\n",
      "        [1., 4., 1.],\n",
      "        [1., 4., 1.],\n",
      "        [1., 4., 1.],\n",
      "        [1., 4., 1.],\n",
      "        [1., 4., 1.],\n",
      "        [1., 4., 1.],\n",
      "        [1., 4., 1.],\n",
      "        [1., 4., 1.]])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch_scatter import segment_coo\n",
    "\n",
    "src = torch.ones(10, 6)\n",
    "index = torch.tensor([0, 1, 1, 1, 1, 2])\n",
    "index = index.view(1, -1)  # Shape (1, 6): broadcast over the first dim of src.\n",
    "\n",
    "out = segment_coo(src, index, reduce=\"sum\")\n",
    "\n",
    "print(out.size())\n",
    "print(out)"
   ]
  },
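  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check, the sketch below compares `segment_coo` with `scatter` on the same sorted index. It uses a random `src` so that the groups are not trivially equal. The tolerance `atol=1e-6` is an assumption that allows for a different floating-point summation order; with it, the comparison is expected to print `True`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch_scatter import scatter, segment_coo\n",
    "\n",
    "src = torch.randn(10, 6)\n",
    "index = torch.tensor([0, 1, 1, 1, 1, 2])  # sorted, as segment_coo requires\n",
    "\n",
    "out_scatter = scatter(src, index, dim=1, reduce=\"sum\")\n",
    "out_coo = segment_coo(src, index.view(1, -1), reduce=\"sum\")\n",
    "\n",
    "# Both reduce columns 0 | 1-4 | 5 into three groups.\n",
    "print(out_scatter.shape, out_coo.shape)\n",
    "print(torch.allclose(out_scatter, out_coo, atol=1e-6))"
   ]
  },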
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-06-07T09:03:24.418161Z",
     "start_time": "2021-06-07T09:03:24.414120Z"
    }
   },
   "source": [
    "## segment_csr"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-06-07T12:02:02.433476Z",
     "start_time": "2021-06-07T12:02:02.424336Z"
    }
   },
   "source": [
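    "`src` shape: $(x_0,...,x_{m-1},x_m,x_{m+1},...,x_n)$  \n",
    "`indptr` shape: $(x_0,...,x_{m-1},y)$  \n",
    "`out` shape: $(x_0,...,x_{m-1},y-1,x_{m+1},...,x_n)$\n",
    "\n",
    "Reshape `indptr` with `view` so that every dimension before the grouping dimension is 1 and the grouping dimension is -1.  \n",
    "The function treats each value of `indptr` as the start of a group and the next value as its exclusive end. In the example below, `indptr = [0, 1, 3, 5, 6]` splits the 6 columns into  \n",
    "[0] [1 2] [3 4] [5]  \n",
    "with the final value 6 only marking the end of the last group.\n",
    "\n",
    "`segment_csr` is generally the fastest of the three."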
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-06-07T12:12:27.601893Z",
     "start_time": "2021-06-07T12:12:27.590132Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([10, 4])\n",
      "tensor([[1., 2., 2., 1.],\n",
      "        [1., 2., 2., 1.],\n",
      "        [1., 2., 2., 1.],\n",
      "        [1., 2., 2., 1.],\n",
      "        [1., 2., 2., 1.],\n",
      "        [1., 2., 2., 1.],\n",
      "        [1., 2., 2., 1.],\n",
      "        [1., 2., 2., 1.],\n",
      "        [1., 2., 2., 1.],\n",
      "        [1., 2., 2., 1.]])\n"
     ]
    }
   ],
   "source": [
    "from torch_scatter import segment_csr\n",
    "\n",
    "src = torch.ones(10, 6)\n",
    "indptr = torch.tensor([0, 1, 3, 5, 6])\n",
    "indptr = indptr.view(1, -1)  # Shape (1, 5): broadcast over the first dim of src.\n",
    "\n",
    "out = segment_csr(src, indptr, reduce=\"sum\")\n",
    "\n",
    "print(out.size())\n",
    "print(out)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "song",
   "language": "python",
   "name": "song"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  },
  "latex_envs": {
   "LaTeX_envs_menu_present": true,
   "autoclose": false,
   "autocomplete": true,
   "bibliofile": "biblio.bib",
   "cite_by": "apalike",
   "current_citInitial": 1,
   "eqLabelWithNumbers": true,
   "eqNumInitial": 1,
   "hotkeys": {
    "equation": "Ctrl-E",
    "itemize": "Ctrl-I"
   },
   "labels_anchors": false,
   "latex_user_defs": false,
   "report_style_numbering": false,
   "user_envs_cfg": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
